[Common/PyTorch] Grouped-quantize kernels for 1D and 2D FP8 block-scaling #3135

Merged: denera merged 1 commit into NVIDIA:main from denera:common/fp8-block-scaling-grouped-quantize on Jul 1, 2026

denera commented Jun 17, 2026

Collaborator

Description

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128) block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization directions. A single CUDA kernel launch walks 128x128 tiles across every tensor in the group, with each CTA decoding its owning tensor from the device-side GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape representations.
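As a rough host-side illustration of the "each CTA decodes its owning tensor" step, the sketch below maps a flat tile index to a (tensor, tile row, tile column) triple for a VARYING_FIRST_DIM group. The function names `build_tile_prefix` and `decode_tile`, the flat tile ordering, and the binary search are assumptions made for illustration; the actual kernels read device-side GroupedTensor metadata and may decode differently.

```python
from bisect import bisect_right
from itertools import accumulate

TILE = 128  # tile edge shared by the 1x128 and 128x128 recipes


def build_tile_prefix(first_dims, k):
    """Return (prefix, tiles_x), where prefix[i] counts tiles owned by tensors before i.

    Hypothetical helper: assumes tensors are laid out back to back and share K.
    """
    tiles_x = -(-k // TILE)  # ceil(K / 128) tile columns
    tiles_per_tensor = [-(-r // TILE) * tiles_x for r in first_dims]
    return [0] + list(accumulate(tiles_per_tensor)), tiles_x


def decode_tile(flat_tile, prefix, tiles_x):
    """Map a flat tile index (one per CTA) to (tensor_id, tile_row, tile_col)."""
    tensor_id = bisect_right(prefix, flat_tile) - 1  # last tensor starting at or before this tile
    local = flat_tile - prefix[tensor_id]            # tile index within the owning tensor
    return tensor_id, local // tiles_x, local % tiles_x


# Three experts with R = 256, 128, 384 rows and a common K = 256.
prefix, tiles_x = build_tile_prefix([256, 128, 384], 256)
print(prefix)                             # cumulative tile counts per tensor
print(decode_tile(5, prefix, tiles_x))    # tile 5 belongs to the second tensor
```

For SAME_BOTH_DIMS the search is unnecessary, since every tensor owns the same number of tiles and the tensor index follows from one integer division.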

Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:

  • group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads global memory directly into vec-16 registers; bypasses TMA because the shared memory roundtrip and ptx::mbarrier do not buy anything without re-use in the CW path.
  • group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch; TMA bulk-load fills the shared memory input cache. BOTH runs the RW pass first (8 threads/row, vec-16 read from shared memory), then the CW pass; CW-only skips the RW pass. The CW pass uses 2 threads/column with a 64-row register stage in CW-only, and 4 threads/column with a 32-row register stage over two column passes in BOTH. The CW path writes the transposed FP8 tile to a shared memory transpose staging buffer, then drains it to global memory.
  • group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch; TMA bulk-load fills the shared memory input cache. Pass 1 stages 8 IVecs/thread in registers while computing the per-tile scalar amax. Pass 2 quantizes from registers, emits row-wise output, stages column-wise output to the shared memory transpose staging buffer, then drains it to global memory.
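To make the difference between the two recipes concrete, here is a minimal pure-Python reference for how scales are grouped: one scale per row per 128 contiguous columns for 1D, and one scale per 128x128 tile for 2D. It only sketches the row-wise amax-to-scale step. The helper names, the use of the E4M3 maximum (448), and the handling of all-zero blocks are assumptions for illustration, not the kernels' actual code.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3 (assumed output format)


def _amax(values):
    """Largest absolute value in an iterable (0.0 if empty)."""
    return max((abs(v) for v in values), default=0.0)


def _scale(block_amax):
    """Scale that maps the block amax onto the FP8 maximum.

    Assumption: an all-zero block gets scale 1.0.
    """
    return FP8_E4M3_MAX / block_amax if block_amax > 0.0 else 1.0


def scales_1d(matrix, block=128):
    """1D recipe: one scale per row for every `block` contiguous columns."""
    return [
        [_scale(_amax(row[c:c + block])) for c in range(0, len(row), block)]
        for row in matrix
    ]


def scales_2d(matrix, block=128):
    """2D recipe: one scale per `block` x `block` tile."""
    cols = len(matrix[0])
    return [
        [
            _scale(_amax(v for row in matrix[r:r + block] for v in row[c:c + block]))
            for c in range(0, cols, block)
        ]
        for r in range(0, len(matrix), block)
    ]


# Tiny example with block=2 so the grouping is visible by eye.
x = [[1.0, -2.0, 4.0, 0.0],
     [8.0, 1.0, 0.0, 0.0]]
print(scales_1d(x, block=2))  # 2 rows x 2 scales
print(scales_2d(x, block=2))  # 1 tile row x 2 scales
```

The 2D recipe needs a reduction across the whole tile before any element can be quantized, which is consistent with the two-pass structure of the 2D kernel described above.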

Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBLASLt grouped GEMM supports FP8 block-scaling only on Hopper).

PR includes PyTorch integration into te.GroupedTensor only.

PyTorch integration into te.GroupedLinear and JAX integration are deferred to follow-up PRs.

Partially resolves #2525

Performance

Benchmark on H200 with a sweep of grouped tensors in (N, M, K) shapes:

  • N ∈ {4, 8, 16, 32, 64, 128} (# of device-local experts)
  • M = 4096 @ N = 4 → M = 128 @ N = 128 (# of tokens/expert, scaling inversely with # experts)
  • K ∈ {1024, 1792, 2048, 3584, 4096, 7168} (device-local shard of TP-hidden/intermediate-FFN dim)

Two shape families per config:

  • U_MoE (uniform, SAME_BOTH_DIMS): all experts share the (M, K) shape
  • J_MoE (jagged, VARYING_FIRST_DIM): per-expert M drawn from an imbalanced routing, common K

Buckets:

  • Small/Unsaturated (S): R·K ≤ 32M elements (< 2048 tiles and < 15 waves on H200's 132 SMs)
  • Large/Saturated (L): R·K > 32M elements (> 2048 tiles, SMs busy across many waves)
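The bucket boundaries are easier to reason about with the tile arithmetic written out. The sketch below counts 128x128 tiles for a uniform (N, M, K) group and estimates waves over 132 SMs. The number of resident CTAs per SM is left as a parameter because it depends on the kernel and path, so the wave figure is only an estimate under that assumption; the function name is hypothetical.

```python
import math

TILE = 128
H200_SMS = 132  # SM count quoted in the bucket definition above


def tiles_and_waves(n_experts, m_per_expert, k, ctas_per_sm=1):
    """Return (elements, tiles, waves) for a uniform grouped tensor.

    `ctas_per_sm` is an assumption; actual occupancy varies by kernel.
    """
    elements = n_experts * m_per_expert * k
    tiles = n_experts * math.ceil(m_per_expert / TILE) * math.ceil(k / TILE)
    waves = tiles / (H200_SMS * ctas_per_sm)
    return elements, tiles, waves


# N = 8 experts, M = 2048 tokens/expert, K = 1024: an S-bucket configuration.
elements, tiles, waves = tiles_and_waves(8, 2048, 1024)
print(elements, tiles, round(waves, 2))
```

With higher occupancy (for example `ctas_per_sm=3`) the same tile count covers proportionally fewer waves.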

Bucket medians across 3 reps. Speedup is grouped vs the split-quantized fallback that loops over the grouped tensor and quantizes each constituent sequentially. % mono is grouped throughput relative to a single non-grouped FP8 block-scaling quantize on the equivalent monolithic (N·M, K) tensor.

| Bucket | Path | Grouped (ms) | Split (ms) | Speedup | % memcpy tput | % mono tput |
|---|---|---|---|---|---|---|
| S | 1D RW | 0.019 | 0.084 | 4.50× | 64.9 % | 114.9 % |
| S | 1D CW | 0.022 | 0.090 | 4.17× | 56.4 % | 112.4 % |
| S | 1D BOTH | 0.033 | 0.116 | 3.57× | 50.2 % | 102.8 % |
| S | 2D RW | 0.018 | 0.076 | 4.19× | 65.9 % | 100.9 % |
| S | 2D CW | 0.020 | 0.089 | 4.64× | 62.6 % | 126.3 % |
| S | 2D BOTH | 0.027 | 0.089 | 3.68× | 66.3 % | 98.8 % |
| L | 1D RW | 0.058 | 0.198 | 2.06× | 87.1 % | 118.7 % |
| L | 1D CW | 0.064 | 0.213 | 2.05× | 77.3 % | 118.5 % |
| L | 1D BOTH | 0.098 | 0.282 | 1.77× | 66.7 % | 108.4 % |
| L | 2D RW | 0.054 | 0.178 | 1.99× | 87.7 % | 100.1 % |
| L | 2D CW | 0.058 | 0.213 | 2.20× | 85.3 % | 135.9 % |
| L | 2D BOTH | 0.078 | 0.213 | 1.62× | 85.0 % | 102.4 % |

| # experts (N) | S bucket | L bucket |
|---|---|---|
| 4 | 1.74× | 1.42× |
| 8 | 2.40× | 1.45× |
| 16 | 4.18× | 1.89× |
| 32 | 5.50× | 2.81× |
| 64 | 10.43× | 7.51× |
| 128 | 19.81× | 8.72× |

Notes

  • % of mono throughput is roughly consistent across buckets for every path, confirming no per-expert overhead in the new kernels.
  • Greater-than-100% mono throughput cases come from TMA bulk-loads, register staging, and vec-16 reads that the non-grouped FP8 block-scaling kernels do not have.
  • Speedup over split-quantize scales as expected with # of experts (roughly linearly in the unsaturated regime).
  • S-bucket % memcpy is lower than L because launch and per-CTA setup are not amortized over a long bandwidth-bound steady state; absolute kernel times are still small (< 35 µs).

Known Sub-Optimalities

1D CW load bank conflicts on ~35% of load wavefronts (reading from the shared memory input-cache)

  • No possible stride padding or XOR swizzle to alleviate.
  • TMA hardware swizzle with CU_TENSOR_MAP_SWIZZLE_128B has the right pattern but caps FP16/BF16 at 64 elements, so it does not fit the 128-element tile for FP8 block-scaling without doubling per-tile launch overhead (quadrupling for FP32).
  • CW-only stays at 2 t/col. Going to 4 t/col would double the bank-conflict footprint (4 lanes per column at the same row stride instead of 2), and CW-only is not occupancy-bound (3 CTAs/SM regardless), so the restructure costs more than it saves.
  • 1D BOTH uses 4 t/col with a 32-row reg_data per thread and two column passes per CTA. The RW pass's per-expert scale-offset arithmetic plus a 64-row reg_data crossed the 85-reg / 3-CTA threshold on sm_90; halving reg_data restores 4 CTAs/SM. The doubled column pass and extra XOR-reduce stage are cheap relative to the occupancy gain.

1D BOTH reads the shared memory input-cache twice

  • The RW (8 threads/row) and CW (2 threads/column) passes have different threading.
  • Attempted to unify with 8 threads/row for both RW and CW. This caused bank conflicts on ~76% of store wavefronts (writing to the shared memory transpose buffer), reduced to ~43% with an XOR swizzle, which was not enough to beat separate RW/CW passes.
  • Did not pursue the 2 threads/column unification; it costs 40x more shfl ops than the 8 threads/row attempt, plus a shared memory partial buffer and a sync.

2D CW/BOTH has bank conflicts on ~16% of store wavefronts (when writing to the shared memory transpose buffer)

  • Already reduced from ~75% via an XOR swizzle; further reduction was not possible.
  • Minimal impact (< 5%) on kernel time.
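For readers less familiar with why an XOR swizzle helps on transposed shared-memory stores, the toy model below counts bank conflicts for one warp writing a column of a 32x32 tile of 4-byte words. This is a textbook illustration, not the layout used in this PR: the kernels here work on FP8 bytes, 128x128 tiles and different thread mappings, which is presumably why residual conflicts remain. All names are hypothetical.

```python
from collections import Counter

BANKS = 32  # shared memory is organized as 32 banks of 4-byte words


def worst_conflict(word_addresses):
    """Largest number of distinct word addresses that land in the same bank."""
    per_bank = Counter(addr % BANKS for addr in set(word_addresses))
    return max(per_bank.values())


def column_store_addresses(col, swizzle):
    """Word addresses touched when 32 lanes store one column of a 32x32 tile.

    Lane i writes row i. With swizzling, the column index is XORed with the
    row index so consecutive rows fall into different banks.
    """
    addresses = []
    for lane in range(32):
        c = (col ^ lane) if swizzle else col
        addresses.append(lane * 32 + c)
    return addresses


print(worst_conflict(column_store_addresses(5, swizzle=False)))  # every lane hits one bank
print(worst_conflict(column_store_addresses(5, swizzle=True)))   # every lane hits its own bank
```

The reader of the swizzled buffer has to apply the same XOR when computing addresses, which is part of the cost of this scheme.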

No TMA-store

  • MXFP8 grouped quantize kernel leverages this by decomposing a 128x128 tile into 32-row sub-stages that each have their own independent 32x1 or 1x32 scale; shared memory footprint is based on a single sub-stage; can be quantized and TMA-stored independently; hides TMA-store of one stage under the compute of next stage.
  • The FP8 block-scaling 128-element scale block spans the entire 128-row tile, so the tile cannot be decomposed into independent sub-stages with pipelined TMA-stores. A single non-pipelined TMA-store requires holding the transposed staging buffer until all work on the tile is finished, which blows up the shared memory footprint and collapses occupancy to 2 CTAs/SM. The recipe itself is the roadblock.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera requested review from ptrendx and vthumbe1503 June 17, 2026 13:01
@denera denera self-assigned this Jun 17, 2026
@denera denera added performance Performance issues FP8 MoE labels Jun 17, 2026
Comment thread transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh Outdated
greptile-apps Bot commented Jun 17, 2026

Contributor

Greptile Summary

Adds fused grouped-tensor FP8 block-scaling quantize (1D and 2D) and dequantize kernels for Hopper (SM90-SM99), plus full PyTorch integration into te.GroupedTensor. A single CUDA launch walks 128×128 tiles across every tensor in the group, with each CTA decoding its owning expert from device-side GroupedTensor metadata.

  • group_quantize_fp8_blockwise.cuh and group_dequantize_fp8_blockwise.cuh: new SM90-gated kernels with per-expert compact scale layouts matching cuBLAS grouped GEMM's expectation; TMA bulk-load (shared::cta scope) is used for the CW and BOTH paths.
  • quantizer.cpp / cast.cpp: Float8BlockQuantizer::create_grouped_tensor now sizes per-expert scale buffers correctly, rejects force_pow_2_scales=True, and sets with_gemm_swizzled_scales=false.
  • ptx.cuh: mbarrier/TMA helpers lowered from SM100 to SM90 guard; new cp_async_bulk_tensor_2d_global_to_shared_cta added for the shared::cta TMA variant needed by the Hopper-targeted kernels.

Confidence Score: 4/5

The new CUDA kernels and C++ dispatchers are well-constructed; the two defects are both in the Python test helper, not in the production path.

The make_quantizer helper in test_grouped_tensor.py sets force_pow_2_scales=True for the fp8_blockwise case, while Float8BlockQuantizer::create_grouped_tensor now explicitly rejects that flag — every test parameterized with _quantization_params[fp8_blockwise] that calls grouped_tensor.quantize() will raise NVTE_ERROR on Hopper before any kernel fires. Separately, the skip-condition for that param entry uses fp8_block_scaling_available rather than fp8_block_scaling_grouped_available, allowing the tests to run on Blackwell where the SM90-SM99 guard would immediately reject them.

tests/pytorch/test_grouped_tensor.py — the make_quantizer helper and the _quantization_params skip condition for fp8_blockwise both need correction.

Important Files Changed

Filename Overview
transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh New 1022-line file implementing 1D/2D block-scaling grouped quantize kernels (RW, CW, BOTH) with SM90-guard, TMA usage, and per-expert scale layout helpers.
transformer_engine/common/cast/fp8_blockwise/group_dequantize_fp8_blockwise.cuh New 519-line file implementing dequantize kernels for all four {1D,2D}x{RW,CW} combinations, correctly mirroring the quantize scale layouts.
tests/pytorch/test_grouped_tensor.py New test file: make_quantizer for fp8_blockwise uses force_pow_2_scales=True (rejected by new NVTE_CHECK) and the _quantization_params skip condition uses the wrong availability flag.
transformer_engine/pytorch/csrc/quantizer.cpp Float8BlockQuantizer::create_grouped_tensor updated: rejects force_pow_2_scales, correctly sizes per-expert scale buffers including 2D CW slack, sets with_gemm_swizzled_scales=false.
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds FP8_BLOCKWISE_GROUPED_QUANTIZE dispatch in group_quantize and bgrad_group_quantize; updates scale_dtype for block-scaling in group_dequantize.
transformer_engine/common/util/ptx.cuh Lowers mbarrier/TMA PTX helpers from SM100 to SM90 guard; adds cp_async_bulk_tensor_2d_global_to_shared_cta for shared::cta scope TMA.
tests/cpp/operator/test_cast_float8blockwise_grouped.cu New C++ test exercising all {1D,2D}x{RW,CW,BOTH}x{SAME_BOTH_DIMS,VARYING_FIRST_DIM} combinations against split-quantize reference; correctly uses force_pow_2_scales=false.
transformer_engine/common/cast/dispatch/quantize.cuh Adds NVTE_BLOCK_SCALING_1D/2D cases to group_quantize_fwd_helper and group_quantize_bwd_helper with correct IS_ACT/force_pow_2_scales guards.
transformer_engine/common/cast/dispatch/dequantize.cuh Adds NVTE_BLOCK_SCALING_1D/2D case to group_dequantize_helper delegating to fp8_blockwise::group_dequantize.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Python group_quantize] --> B{quantizer type?}
    B -->|Float8BlockwiseQuantizer| C[FP8_BLOCKWISE_GROUPED_QUANTIZE]
    B -->|MXFP8| D[MXFP8 path]
    B -->|NVFP4| E[NVFP4 path]
    C --> G[create_grouped_tensor - NVTE_CHECK force_pow_2_scales==false]
    G --> H{scaling_mode?}
    H -->|BLOCK_SCALING_1D| I[group_quantize_blockwise_1d]
    H -->|BLOCK_SCALING_2D| J[group_quantize_blockwise_2d]
    I --> K{RW-only no dbias?}
    K -->|yes| L[group_block_scaled_1d_rw_kernel - no smem]
    K -->|no| M[group_block_scaled_1d_tma_kernel - TMA CW/BOTH]
    J --> N[group_block_scaled_2d_tma_kernel - TMA pass1 amax pass2 quant]

Comments Outside Diff (1)

  1. tests/pytorch/test_grouped_tensor.py, line 92-100 (link)

    P1 force_pow_2_scales=True makes every parameterized grouped-quantize test fail on Hopper

    Float8BlockQuantizer::create_grouped_tensor now contains an explicit NVTE_CHECK(!force_pow_2_scales, "Fused grouped FP8 block-scaling quantize does not support force_pow_2_scales=True"). That check fires before any kernel is launched: it is inside create_grouped_tensor, which is called from group_quantize in cast.cpp when building grouped_output_tensor_cpp. Concretely, test_quantize_inplace["fp8_blockwise"] and test_quantize_varying_shapes["fp8_blockwise"] both call grouped_tensor.quantize(input_tensors) → group_quantize → Float8BlockQuantizer::create_grouped_tensor, and will raise NVTE_ERROR on any Hopper machine. The C++ test (test_cast_float8blockwise_grouped.cu, line 377) already uses force_pow_2_scales=false as the correct baseline; the Python helper should match it.
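    For context on what the rejected flag changes numerically, the sketch below contrasts an unconstrained FP32 scale with a power-of-2 scale for the same block amax. It assumes the constraint rounds the scale down to the nearest power of two and that the target format is E4M3; the exact rule in Transformer Engine may differ, and `block_scale` is a hypothetical helper, not a library function.

```python
import math

FP8_E4M3_MAX = 448.0  # assumed FP8 output format


def block_scale(block_amax, force_pow_2=False):
    """Scale for one block; optionally rounded down to a power of two.

    Assumption: all-zero blocks get scale 1.0.
    """
    if block_amax <= 0.0:
        return 1.0
    scale = FP8_E4M3_MAX / block_amax
    if force_pow_2:
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale


print(block_scale(3.0))                    # unconstrained FP32 scale
print(block_scale(3.0, force_pow_2=True))  # rounded down to a power of two
```

    The two settings generally produce different scales for the same data, so a test helper that sets the flag exercises a different numerical path from the one the fused grouped kernels support.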

Reviews (15): Last reviewed commit: "[Common/PyTorch] Grouped-quantize kernel..."

Comment thread tests/cpp/operator/test_cast_float8blockwise_grouped.cu
Comment thread transformer_engine/common/cast/fp8_blockwise/group_quantize_fp8_blockwise.cuh Outdated
denera added a commit to denera/TransformerEngine that referenced this pull request Jun 22, 2026
- Reuse shared helpers (DIVUP, DIVUP_TO_MULTIPLE, TMA_GMEM_ALIGNMENT,
  align_smem_ptr_per_TMA_requirements, get_current_tensor_id,
  subwarp_reduce_max_broadcast) in place of local equivalents.
- Add proxy-async fence after mbarrier_init in 2D + 1D TMA kernels.
- Enforce per-tensor first_dim % 128 device-side for VARYING_FIRST_DIM
  (matches MXFP8 grouped quantize behavior).
- Fix Hopper SM range wording in 1D dispatcher.
- Extend cpp tests to cover with_gemm_swizzled_scales path.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera requested a review from Oleg-Goncharov June 22, 2026 23:06
// num_tiles_X = DIVUP(total_row_blocks, TILE_DIM_X=4)
__device__ __forceinline__ size_t swizzled_colwise_scale_idx(size_t i, size_t j,
size_t total_row_blocks) {
using namespace transformer_engine::dispatch::mxfp8::swizzle;

@vthumbe1503 vthumbe1503 Jun 22, 2026

Collaborator

I think we should rename the namespace for swizzle...given that we use the same constants for mxfp8, nvfp4, fp8 block scaling

Collaborator Author

Grouped GEMM doesn't read FP8 block-scales in swizzled format. It requires a compact per-expert format instead, so I stripped all the swizzle changes out of this PR.

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch 3 times, most recently from dc998bd to 6a25307 Compare June 29, 2026 13:54
@denera denera requested a review from vthumbe1503 June 29, 2026 13:57
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch 3 times, most recently from 1c17b49 to 20c98e5 Compare June 30, 2026 07:16
denera commented Jun 30, 2026

Collaborator Author

/te-ci cpp pytorch

@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from 20c98e5 to ab816b5 Compare June 30, 2026 15:41
@denera denera added the 2.17 label Jun 30, 2026
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from ab816b5 to 7b115f5 Compare June 30, 2026 18:20
denera commented Jun 30, 2026

Collaborator Author

/te-ci core pytorch

@vthumbe1503 vthumbe1503 left a comment

Collaborator

LGTM! We should probably move group dequantize to a different PR or add a test for it in this PR.

}
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
fp8_blockwise::group_dequantize(&input, output, stream);

Collaborator

We would need group dequantize tests as well for fp8 blockwise.
Alternatively, it can be separated out into a different PR, since group dequant is not a priority for GroupedLinear integration.

Collaborator Author

PR already includes test_grouped_tensor.py::test_group_dequantize_fp8_blockwise so the dequantize is tested, but we're missing test_group_dequantize_cudagraph_capturable. I'll add it for parity with grouped MXFP8.

@vthumbe1503 vthumbe1503 Jul 1, 2026

Collaborator

Do we have a C++ test in common though, considering JAX might also need it?

Collaborator Author

We do now. Just pushed the commit with the new tests, both C++ and PyTorch, passing on H100.

No tests on the JAX side. The PR doesn't include any JAX changes. I punted that to a separate PR later this week because JAX needs a lot more changes to integrate FP8BS. Not quite as drop-in ready as PyTorch GroupedTensor.

NVTE_CHECK(info.tensor_offsets_d != nullptr,
"VARYING_FIRST_DIM requires tensor_offsets to be set on the GroupedTensor.");
}
info.total_row_blocks = DIVUP(info.R_total, static_cast<size_t>(kTileDim));

Collaborator

We might want a static persistent kernel, similar to MXFP8, for CUDA graphs in the future.

With CUDA graphs, total_rows can be larger than the sum of the first dims, so we would be over-launching thread blocks.

Not a blocker for the current PR. Persistent kernel based optimization can be a future PR if necessary.

Collaborator Author

Sure, I'll turn that into a separate PR. Would be a self-contained change with no framework integration impact so it should be something very minimal to review and merge later.

vthumbe1503 previously approved these changes Jul 1, 2026
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from 7b115f5 to b00947e Compare July 1, 2026 16:34
…ling

Implements grouped-tensor quantize for the FP8 1D (1x128) and 2D (128x128)
block-scaling recipes in row-wise (RW), column-wise (CW) and BOTH quantization
directions. A single CUDA kernel launch walks 128x128 tiles across every tensor
in the group, with each CTA decoding its owning tensor from the device-side
GroupedTensor metadata with (N, R, K) shapes. Supports SAME_BOTH_DIMS (all
tensors identical) and VARYING_FIRST_DIM (constant K, varying R) shape
representations.

Three kernels share the dispatcher in group_quantize_blockwise_{1d,2d}:
- group_block_scaled_1d_rw_kernel: RW-only dispatch; 8 threads/row, reads
  global memory directly into vec-16 registers; bypasses TMA since the
  shared-memory roundtrip and ptx::mbarrier do not buy anything without
  re-use in the CW path.
- group_block_scaled_1d_tma_kernel: CW-only and BOTH dispatch. TMA bulk-load
  fills shared memory input cache. BOTH runs an RW pass (8 threads/row,
  vec-16 read from shared memory) then a CW pass; CW-only skips the RW
  pass. The CW pass uses 4 t/col with 32-row reg_data and two column passes
  in the BOTH instantiation (keeps the per-thread register footprint under
  the sm_90 3-CTAs/SM threshold) and 2 t/col with 64-row reg_data in the
  CW-only instantiation (avoids doubling the smem-load bank-conflict
  footprint that 4 t/col would introduce).
- group_block_scaled_2d_tma_kernel: RW-only, CW-only and BOTH dispatch. TMA
  bulk-load fills shared memory input cache. Pass 1 stages 8 IVecs/thread
  in registers while computing the per-tile scalar amax. Pass 2 quantizes
  from registers, emits row-wise output, stages column-wise output to the
  shared memory transpose staging buffer, then drains smem_T to global
  memory.

Per-expert scale offsets:
- 1D RW: closed-form O(1) for both SAME_BOTH_DIMS and VARYING_FIRST_DIM
  (each M_i is a multiple of kTileDim=128, hence of kScaleColAlign=4, so
  DIVUP_TO_MULTIPLE collapses and the prefix sum reduces to a single
  tensor_offsets_ptr[tensor_id]/K load).
- 2D CW: closed-form O(1) for SAME_BOTH_DIMS; CTA-cooperative warp-shuffle
  prefix sum for VARYING_FIRST_DIM (non-linear DIVUP_TO_MULTIPLE on
  blocks_y_t prevents a closed form). The cooperative reduction uses the
  existing warp_allreduce_sum helper from common/utils.cuh.

Dequantize and bias-gradient (bgrad):
- group_dequantize_fp8_blockwise.cuh: kernels for all four modes
  (1D/2D x rowwise/columnwise), inverting the per-expert layouts the
  quantize kernels write.
- bgrad_group_quantize accepts Float8Block quantizers and computes dbias
  per-tile column-partial in-kernel (mirroring MXFP8); reduced per expert
  via the existing common::grouped_reduce_dbias.

Scale constraints: the fused grouped FP8BS path supports only unconstrained
FP32 scales (Float8BlockQuantizer::create_grouped_tensor rejects
force_pow_2_scales=True). Power-of-2 scales remain available on the
non-grouped/unfused split-quantize path used for Blackwell MXFP8 emulation.

Tests: existing parametrized grouped quantize / dequantize / bgrad tests
in test_grouped_tensor.py cover MXFP8, NVFP4, FP8 current scaling and the
newly-added FP8 block scaling recipe. tests/cpp/operator/
test_cast_float8blockwise_grouped.cu adds 72 C++ unit-test cases over
uniform/jagged shapes, all four (BD x direction) modes, K in {128, 256,
512}, and CUDA-graph capture coverage.

Kernels are gated to Hopper (sm_90) at the host dispatcher (cuBlasLt
grouped GEMM supports FP8 block-scaling only on Hopper).

JAX integration is intentionally left out of scope and deferred to a
follow-up PR.

Resolves NVIDIA#2525

Signed-off-by: Alp Dener <adener@nvidia.com>
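The "Per-expert scale offsets" argument in the commit message above can be checked with a small host-side model. The sketch assumes a simplified layout in which each expert's 1D RW scale block holds (K/128) x DIVUP_TO_MULTIPLE(M_i, 4) entries and each expert's 2D CW scale block holds (K/128) x DIVUP_TO_MULTIPLE(M_i/128, 4) entries. These layouts and all function names are assumptions made to illustrate the reasoning, not the exact buffers the kernels write.

```python
TILE = 128        # kTileDim in the commit message
SCALE_ALIGN = 4   # kScaleColAlign in the commit message


def divup_to_multiple(x, m):
    """Round x up to the next multiple of m."""
    return (x + m - 1) // m * m


def rw_1d_offset_prefix(first_dims, k, t):
    """Reference: sum the padded scale counts of every expert before t."""
    blocks_x = k // TILE
    return sum(blocks_x * divup_to_multiple(m, SCALE_ALIGN) for m in first_dims[:t])


def rw_1d_offset_closed(tensor_offsets, k, t):
    """O(1) form: the rows before expert t follow from its element offset."""
    return (k // TILE) * (tensor_offsets[t] // k)


def cw_2d_offset_prefix(first_dims, k, t):
    """2D CW: padding is applied per expert to its tile-row count."""
    blocks_x = k // TILE
    return sum(blocks_x * divup_to_multiple(m // TILE, SCALE_ALIGN) for m in first_dims[:t])


first_dims = [128, 384, 256]   # each M_i is a multiple of 128
k = 256
tensor_offsets = [0, 128 * k, 512 * k, 768 * k]  # element offset of each expert

# 1D RW: padding is a no-op, so the closed form matches the prefix sum.
print(rw_1d_offset_prefix(first_dims, k, 2), rw_1d_offset_closed(tensor_offsets, k, 2))

# 2D CW: per-expert padding differs from padding the running total.
naive = (k // TILE) * divup_to_multiple(sum(m // TILE for m in first_dims[:2]), SCALE_ALIGN)
print(cw_2d_offset_prefix(first_dims, k, 2), naive)
```

Because rounding up each expert's tile-row count is not linear in the running total, the jagged 2D CW case needs an actual prefix sum, which matches the cooperative reduction described in the commit message.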
@denera denera force-pushed the common/fp8-block-scaling-grouped-quantize branch from b00947e to 82ea6a3 Compare July 1, 2026 16:50
@denera denera merged commit 9f2074e into NVIDIA:main Jul 1, 2026
10 of 14 checks passed
Development

Successfully merging this pull request may close these issues.

Blockwise (1x128 and 128x128) FP8 grouped quantization

3 participants